#!/usr/bin/env python3

import argparse
import json
import os

from tqdm import tqdm

from src.generation_methods import decode_consensus
from src.new_text_alignment import TextSeqGraphAlignment
from src.text_poa_graph import TextPOAGraph


def process_bios(
    results_dir,
    output_dir,
    model_name,
    short_task_code,
    clean_up_api,
    clean_up_model,
    consensus_texts_output_dir,
    consensus_threshold,
    num_samples=250,
):
    """
    Build (or reload) one POA graph per prompt and write the decoded consensus texts.

    The input is read from
    ``{results_dir}/{model_name}/{short_task_code}/5_samples_halogen_{short_task_code}_gen_{model_name}.json``
    and is expected to be a list of items with ``"Prompt"`` and ``"Responses"`` keys.
    For each of the first ``num_samples`` items the first response seeds a
    ``TextPOAGraph`` and every further response is aligned into it; the graph is then
    refined and saved as ``.pkl`` and ``.html`` under ``{output_dir}/{short_task_code}/``.
    If the ``.pkl`` for an item already exists it is loaded instead of being rebuilt.
    The consensus decoded from every graph is collected and dumped to
    ``{consensus_texts_output_dir}/{short_task_code}_consensus_texts_{model_name}.json``.

    Args:
        results_dir (str): Directory containing the input JSON files
        output_dir (str): Directory to save the output graphs
        model_name (str): Name of the model (e.g., llama70b, llama8b, olmo7b)
        short_task_code (str): Task code (e.g., bio)
        clean_up_api (str): Passed to ``decode_consensus`` as ``api``
        clean_up_model (str): Passed to ``decode_consensus`` as ``model``
        consensus_texts_output_dir (str): Directory to save the consensus texts JSON
        consensus_threshold (float): Passed to ``decode_consensus`` as ``selection_threshold``
        num_samples (int): Number of samples to process

    Raises:
        FileNotFoundError: If the input JSON file does not exist.
    """
    # Graph files live in a per-task subdirectory; create it (and output_dir) up front
    # so the HTML/pickle writes below cannot fail on a missing directory.
    graphs_dir = os.path.join(output_dir, short_task_code)
    os.makedirs(graphs_dir, exist_ok=True)
    os.makedirs(consensus_texts_output_dir, exist_ok=True)

    # Load bios from JSON file
    input_file = f"{results_dir}/{model_name}/{short_task_code}/5_samples_halogen_{short_task_code}_gen_{model_name}.json"
    with open(input_file, "r", encoding="utf-8") as f:
        bios = json.load(f)

    consensus_texts = []

    # Process each bio; wrapping the list (not the enumerate) lets tqdm show a total
    for j, item in enumerate(tqdm(bios[:num_samples])):
        entity = item["Prompt"]

        # "bio" graphs are keyed by the prompt text, every other task by sample index
        graph_key = entity if short_task_code == "bio" else j
        output_base = os.path.join(
            graphs_dir, f"{short_task_code}_graph_{graph_key}_merged_{model_name}"
        )
        pickle_path = f"{output_base}.pkl"

        # Test the full path: os.listdir(output_dir) only yields top-level names and
        # therefore never matched the "<task>/<file>.pkl" relative path.
        if os.path.exists(pickle_path):
            print(f"Skipping graph construction for {j}")
            # NOTE(review): presumably this unpickles the file, so output_dir should
            # only ever contain files written by a trusted run of this script.
            graph = TextPOAGraph.load_from_pickle(pickle_path)
        else:
            entity_bios_text = item["Responses"]
            print(f"Processing entity: {entity}")
            print("Calculating graphs")

            # Create initial graph
            graph = TextPOAGraph(entity_bios_text[0], label=0)

            # Incorporate remaining samples (labels continue from 1)
            for label, sample in enumerate(entity_bios_text[1:], start=1):
                alignment = TextSeqGraphAlignment(
                    text=sample,
                    graph=graph,
                    fastMethod=True,
                    globalAlign=True,
                    matchscore=1,
                    mismatchscore=-2,
                    gap_open=-1,
                )
                graph.incorporateSeqAlignment(alignment, sample, label=label)

            # Refine graph
            graph.refine_graph(verbose=False, domain="text", model="gpt-4.1-mini")

            # Save HTML visualization
            with open(f"{output_base}.html", "w+") as f:
                graph.htmlOutput(f, annotate_consensus=False)

            # Save pickle file
            graph.save_to_pickle(pickle_path)

        # Cached and freshly built graphs are decoded the same way
        consensus_texts.append(
            {
                "Prompt": entity,
                "Responses": decode_consensus(
                    graph,
                    api=clean_up_api,
                    model=clean_up_model,
                    selection_threshold=consensus_threshold,
                ),
            }
        )

    consensus_path = os.path.join(
        consensus_texts_output_dir, f"{short_task_code}_consensus_texts_{model_name}.json"
    )
    with open(consensus_path, "w", encoding="utf-8") as f:
        json.dump(consensus_texts, f)


def main():
    """Parse the command-line arguments and run ``process_bios`` with them."""
    parser = argparse.ArgumentParser(description="Process bios and create POA graphs")
    # The three directory options used to combine required=True with a default, which
    # made the defaults unreachable; they are now optional so the defaults apply.
    parser.add_argument(
        "--results-dir",
        type=str,
        default="../results/consensus_texts/vp_data/",
        help="Directory containing the input JSON files "
        "(default: ../results/consensus_texts/vp_data/)",
    )
    parser.add_argument(
        "--graphs-output-dir",
        type=str,
        default="../results/consensus_texts/vp_data/graphs/",
        help="Directory to save the output graphs "
        "(default: ../results/consensus_texts/vp_data/graphs/)",
    )
    parser.add_argument(
        "--consensus-texts-output-dir",
        type=str,
        default="../results/consensus_texts/HALoGEN/",
        help="Directory to save the consensus texts "
        "(default: ../results/consensus_texts/HALoGEN/)",
    )
    parser.add_argument(
        "--consensus-threshold", type=float, default=0.5, help="Consensus threshold (default: 0.5)"
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["llama70b", "llama8b", "olmo7b", "qwen72b"],
        help="Model to process (llama70b, llama8b, olmo7b, or qwen72b)",
    )
    parser.add_argument(
        "--clean-up-api",
        type=str,
        default="openai",
        help="API to use for clean-up (default: openai)",
    )
    parser.add_argument(
        "--clean-up-model",
        type=str,
        default="gpt-4o-mini",
        help="Model to use for clean-up (default: gpt-4o-mini)",
    )
    parser.add_argument("--task", type=str, default="bio", help="Task code (default: bio)")
    parser.add_argument(
        "--num-samples", type=int, default=250, help="Number of samples to process (default: 250)"
    )

    args = parser.parse_args()

    process_bios(
        results_dir=args.results_dir,
        # argparse derives the attribute name from the option string, so
        # --graphs-output-dir is stored as graphs_output_dir (there is no output_dir).
        output_dir=args.graphs_output_dir,
        model_name=args.model,
        short_task_code=args.task,
        num_samples=args.num_samples,
        clean_up_api=args.clean_up_api,
        clean_up_model=args.clean_up_model,
        consensus_texts_output_dir=args.consensus_texts_output_dir,
        consensus_threshold=args.consensus_threshold,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
